Introduction

The purpose of this project is to predict heart disease in a patient based on their health status.

Inspiration

Heart disease remains one of the most critical and widespread health concerns worldwide, yet detecting and preventing it early can significantly improve patient outcomes. Over the years, I’ve seen friends and family members go through various heart‐related scares such as hypertension. In my mind, the sooner we detect the disease and apply preventive measures, the better.

An insight came to me: if we can predict which individuals are at high risk of heart disease based on their habits and health status, we can notify them and encourage them to seek treatment earlier. This curiosity led me to search for data online, where I found this heart failure dataset, a perfect fit for the prediction topic. With that question in mind, let’s delve into the algorithms and models!


Brief Overview and Outline

Now that we know the goal of predicting heart disease, we can map out a plan for the project. We start by acquiring and cleaning the heart‐failure data, ensuring that each predictor is properly formatted, including the type of each predictor and any missing values. Once the data is tidy, we perform an Exploratory Data Analysis (EDA) to understand relationships among variables—examining how features such as cholesterol, resting blood pressure, and exercise angina may correlate with the heart disease outcome.


Next, we split the data and set up k-fold cross validation on the training set, which allows us to tune parameters and compare models. In the project, we will build six models: Logistic Regression, K-Nearest Neighbors (KNN), Support Vector Machine (SVM), Random Forest, Linear Discriminant Analysis (LDA), and Quadratic Discriminant Analysis (QDA). Among these, we will choose the model that performs best in cross validation by evaluating accuracy, re-fit it to the full training set, and assess it on the test set. Let’s get started.

Loading Packages and Data

# load packages 
library(tidyverse)
library(dplyr)
library(tidymodels)
library(readr)
library(kknn)
library(janitor)
library(ISLR)
library(glmnet)
library(corrr)
library(corrplot)
library(randomForest)
library(vip)
library(ranger)
library(ggplot2)
library(discrim)
library(patchwork) 
library(kernlab)
library(gridExtra)

dat_raw <- read_csv("heart.csv")
head(dat_raw)

The data used in this project is obtained from Kaggle’s “Heart Failure Prediction Dataset”. This dataset was created by combining several previously independent datasets and is maintained by Kaggle user fedesoriano.

The data comprises 918 observations of 12 attributes, though we may drop some variables if collinearity or redundancy emerges in the analysis of later sections.

Exploratory Data Analysis

First, let’s convert all the column names into snake case for convenience and turn all the categorical variables into factors.

dat_raw <- dat_raw |> clean_names()

dat_raw$sex <- as.factor(dat_raw$sex)
dat_raw$chest_pain_type <- as.factor(dat_raw$chest_pain_type)
dat_raw$resting_ecg <- as.factor(dat_raw$resting_ecg)
dat_raw$exercise_angina <- as.factor(dat_raw$exercise_angina)
dat_raw$st_slope <- as.factor(dat_raw$st_slope)
dat_raw$heart_disease <- as.factor(dat_raw$heart_disease)
dat_raw$fasting_bs <- as.factor(dat_raw$fasting_bs)

head(dat_raw)

Data Wrangling

The raw data stores categorical variables as text levels, such as “M/F” for sex and “ATA/NAP” for chest_pain_type. For later model building, we convert them into numeric factor levels, saved as dat.

dat_numeric <- dat_raw |>
  mutate(across(where(is.factor), ~ as.numeric(as.factor(.)))) |>
  mutate(
    sex = sex - 1,  # Convert 1/2 to 0/1
    fasting_bs = fasting_bs - 1,
    heart_disease = heart_disease - 1, 
    exercise_angina = exercise_angina - 1
  )

dat <- dat_numeric |>
  mutate(
    sex = as.factor(sex), 
    chest_pain_type = as.factor(chest_pain_type), 
    fasting_bs = as.factor(fasting_bs), 
    resting_ecg = as.factor(resting_ecg), 
    exercise_angina = as.factor(exercise_angina), 
    st_slope = as.factor(st_slope), 
    heart_disease = as.factor(heart_disease)
  )

head(dat)

Predictor Descriptions

Explanations of the variables, including the numeric coding of the categorical levels, for interpretability:

age: age of the observation/participant

sex: gender of the observation/participant [1 = male, 0 = female]

chest_pain_type:

  • ASY (1): Asymptomatic, no chest pain or symptoms

  • ATA (2): Atypical Angina, chest pain that is not related to heart disease but may have other causes

  • NAP (3): Non-Anginal Pain, chest pain that is not related to the heart

  • TA (4): Typical Angina, chest pain caused by reduced blood flow to the heart

resting_bp: resting blood pressure [mm Hg]

cholesterol: serum cholesterol [mg/dL]

fasting_bs:

  • 0: Fasting Blood Sugar < 120 mg/dL (Normal)

  • 1: Fasting Blood Sugar ≥ 120 mg/dL (High, Potential Risk of Diabetes)

resting_ecg:

  • LVH (1): left ventricle of the heart is enlarged

  • Normal (2): no significant abnormalities in the heart’s electrical activity

  • ST (3): ST segment or T wave abnormalities

max_hr: maximum heart rate achieved [Numeric value between 60 and 202]

exercise_angina: exercise-induced angina [1 = Yes, 0 = No]

oldpeak: ST depression induced by exercise relative to rest [numeric value]

st_slope: Slope of the Peak Exercise ST Segment [1 = Down, 2 = Flat, 3 = Up]

heart_disease: output class [1 = heart disease, 0 = Normal]
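As a quick sanity check (an optional sketch, not part of the original pipeline), we can verify that the numeric codes in dat line up with the label descriptions above by cross-tabulating the recoded levels against the raw labels still stored in dat_raw:

```r
# optional check: the recoded numeric levels should map one-to-one
# onto the original text labels (alphabetical order of factor levels)
table(raw = dat_raw$chest_pain_type, recoded = dat$chest_pain_type)
table(raw = dat_raw$st_slope, recoded = dat$st_slope)
```

Each raw label should correspond to exactly one numeric code—e.g., ASY to 1 and TA to 4—since as.factor() orders levels alphabetically.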

Missing Values

Before we move on to exploring the variables in our data, we must check for any missing data because that could potentially cause issues.

sum(is.na(dat))
## [1] 0

Nice! There are actually no missing values in the dataset, so we don’t need to worry about NAs in later analysis and modeling.

Visual EDA

Now that the data is tidied up, it’s time to do some visualization to better understand the data and the relationships between variables.

Age Distribution

First, let’s use a histogram to see the age distribution.

ggplot(dat, aes(x = age)) +
  geom_histogram(binwidth = 5, fill = "lightblue", color = "black", alpha = 0.7) +
  labs(title = "Age Distribution", x = "Age", y = "Count") +
  theme_minimal()

The histogram appears roughly bell-shaped, suggesting a normal-like distribution. The highest frequency occurs around ages 50-60, and the slight right skew indicates a tail of older individuals (70+ years) relative to the younger end of the range.

Are Chest Pain and Resting ECG Correlated with Heart Disease?

Next, we want to explore whether chest pain and resting ECG are related to heart disease. The bar plots below show the proportion of heart disease cases (1) and non-cases (0) across chest pain type and resting ECG categories.

# chest pain plot 
plot1 <- ggplot(dat, aes(x = chest_pain_type, fill = heart_disease)) +
  geom_bar(position = "fill") + # "fill" makes it a proportion plot
  labs(title = "Chest Pain Type vs. Heart Disease", x = "Chest Pain Type", y = "Proportion") +
  theme_minimal()

# resting ecg plot 
plot2 <- ggplot(dat, aes(x = resting_ecg, fill = heart_disease)) +
  geom_bar(position = "fill") + # "fill" makes it a proportion plot
  labs(title = "Resting ECG vs. Heart Disease", x = "Resting ECG", y = "Proportion") +
  theme_minimal()

plot1 + plot2

On the left, it is evident that individuals with asymptomatic (ASY) chest pain have the highest proportion of heart disease cases, suggesting that this category is strongly associated with the condition. Typical angina (TA) also shows a relatively higher proportion of heart disease cases compared to NAP and ATA, implying a moderate risk factor. In contrast, atypical angina (ATA) and non-anginal pain (NAP) are more common in individuals without heart disease.

Similarly, the second chart highlights the relationship between resting ECG results and heart disease. Individuals with ST-T wave abnormalities (ST) have the highest proportion of heart disease cases, suggesting a strong association between this ECG finding and heart disease risk. Meanwhile, individuals with normal ECG results are more likely to be free of heart disease, but a significant portion still has the condition, so a normal ECG alone does not rule out heart disease. Individuals with left ventricular hypertrophy (LVH) show an almost equal split between cases and non-cases. Taken together, these patterns suggest that incorporating these categorical variables into a predictive model could enhance its ability to identify individuals at risk.

Correlation Plot

Having gained some awareness of the relationships between a few features, we now want to dig deeper for an overview of the whole dataset.

# select only numeric values 
dat_numeric <- dat |>
  select_if(is.numeric)

# correlation matrix 
dat_nu_cor <- cor(dat_numeric)

# correlation plot 
corrplot(dat_nu_cor, method = "circle", addCoef.col = 1, number.cex = 0.7)

One relationship that stands out in the correlation plot is the moderate negative correlation between age and max_hr (-0.38). This indicates that older individuals tend to have lower maximum heart rates, which aligns with medical knowledge. Moreover, all of the other numeric variables show some correlation with max_hr, suggesting that maximum heart rate may be an important predictor of heart disease, especially in older individuals.

Another interesting relationship is between age and oldpeak (0.26). A positive correlation suggests that older individuals tend to have higher ST depression values (oldpeak), which is often linked to ischemia (reduced blood flow to the heart). Further visualization can give a closer look at these relationships.

Maximum Heart Rate with Age Change

Since we noted the relationship between max_hr and age in the correlation plot, we should explore the relationship between the two variables further.

ggplot(dat, aes(x = age, y = max_hr, color = heart_disease)) +
  geom_point(alpha = 0.6) +
  geom_smooth(method = "lm", se = FALSE) +
  labs(title = "Max Heart Rate vs. Age (Colored by Heart Disease)",
       x = "Age", y = "Max Heart Rate") +
  theme_minimal()

Although the relationship is not very strong, it is apparent that there is a moderately negative relationship between age and max_hr. As age increases, max heart rate tends to decrease, which is expected based on physiology (older individuals generally have lower max HRs due to decreased cardiovascular efficiency). With the color classification between with and without heart disease, the two fitted lines suggest the individuals with heart disease tend to have lower max heart rates at the same age compared to those without heart disease.

Age vs. ST Depression

# Categorize Age into Groups
dat1 <- dat %>%
  mutate(age_group = cut(age, breaks = c(20, 30, 40, 50, 60, 70, 80),
                         labels = c("20-29", "30-39", "40-49", "50-59", "60-69", "70-79")))

# Boxplot of Oldpeak by Age Group
ggplot(dat1, aes(x = age_group, y = oldpeak, fill = age_group)) +
  geom_boxplot(alpha = 0.7) +
  labs(title = "Boxplot of Oldpeak by Age Group",
       x = "Age Group", y = "ST Depression (Oldpeak)") +
  theme_minimal() +
  theme(legend.position = "none")  # Remove legend since Age Group is on x-axis

The boxplot represents a slightly positive trend that supports the result from the correlation plot. As age increases, median oldpeak values tend to rise, indicating that older individuals generally experience higher ST depression, which is often linked to ischemia and heart stress. The variability in oldpeak also increases with age, with wider interquartile ranges and more extreme outliers in the older age groups (50+). This suggests that while some older individuals maintain normal ST depression levels, others experience severe ischemic conditions, making oldpeak a potential indicator of cardiovascular risk.

Set Up Models

With all the EDA above, we now have a better idea of which variables affect heart failure. Before fitting the models, we need some preparation: splitting the data, building a recipe, and creating k-fold cross validation.

Data Split

The first step of the setup is splitting the data into training and test sets. The training set is used for model fitting and selection, while the test set is used to assess the accuracy of the selected model. We use cross-validated accuracy to find the best model on the training set. Once we choose the best one, we re-fit it to the entire training set and compute accuracy on the test set to see how it performs on new data. Since the dataset has fewer than 1,000 observations, we choose a 70/30 split: 70% training data, 30% testing data. Moreover, the split is stratified on the outcome variable, heart_disease, to ensure that the class proportions of the full dataset are maintained in both the training and testing sets.

set.seed(0320)

# split dat 
dat_split <- initial_split(dat, prop = 0.7, strata = heart_disease)
dat_train <- training(dat_split)
dat_test <- testing(dat_split)

Before moving on, we need to verify if the data has been split correctly.

# test for dat
nrow(dat_train)/nrow(dat)
## [1] 0.6993464
nrow(dat_test)/nrow(dat)
## [1] 0.3006536

The proportions shown above confirm the validity of the data split. Let’s move on!

Build Recipe

Throughout all models, we will consistently use all predictors in the dataset, so we create a universal recipe to reuse for every model in later training.

Since there are no missing values in the data, the recipe just needs to dummy-code the nominal predictors so the categorical variables can enter the models, normalize the numeric predictors by centering and scaling them, and drop any predictors with zero variance.

# universal recipe shared by all models
dat_recipe <- recipe(heart_disease ~ ., data = dat_train) |> 
  step_dummy(all_nominal_predictors()) |> # dummy-code categorical variables 
  step_center(all_numeric_predictors()) |> # normalizing 
  step_scale(all_numeric_predictors()) |> 
  step_zv(all_predictors()) # remove 0 variance predictors 

K-Fold Cross Validation

In this k-fold cross validation, we choose 10 folds on the training set, meaning the training data is randomly split into 10 equal-sized subsets. Each fold takes a turn as the validation set, while the remaining 9 (k − 1) subsets serve as the training set for that fold. Whichever model we are fitting is trained on those subsets and evaluated on the held-out fold. Finally, the accuracy is averaged across the k folds to measure the overall performance of that model.

We also stratify on our outcome variable, heart_disease, to make sure each fold is balanced, maintaining an even distribution of heart_disease across folds.

# creating folds 
dat_folds <- vfold_cv(dat_train, v = 10, strata = heart_disease)
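As an optional check (not part of the original workflow), we can confirm the stratification worked by computing the proportion of heart disease cases in each fold’s assessment set; with strata = heart_disease these proportions should be nearly identical across folds:

```r
# optional check: class balance in each fold's held-out assessment set
library(purrr)

map_dbl(dat_folds$splits, ~ mean(assessment(.x)$heart_disease == "1"))
```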

Model Building

Finally we get to build our models! To evaluate each model we use accuracy as the metric—the proportion of correctly classified observations—so higher is better. We will fit six models for the heart failure prediction: Logistic Regression, K-Nearest Neighbors (KNN), Support Vector Machine (SVM), Random Forest, Linear Discriminant Analysis (LDA), and Quadratic Discriminant Analysis (QDA). Additionally, we will save each result to an RDS file and load it back to save future run time, since each model may take a while to run.

Fitting the Models

Each model shares similar fitting steps; a step-by-step explanation is presented below.

For each model, we generally have the following steps:

  1. Set up the structure of each model, with its tuning parameters, mode (classification), and engine.
# Logistic Regression: 
log_reg_spec <- logistic_reg(penalty = tune(), mixture = tune()) |> # tune penalty & mixture 
  set_mode("classification") |>
  set_engine("glmnet")

# k-Nearest Neighbors: 
knn_spec <- nearest_neighbor(neighbors = tune()) |>
  set_mode("classification") |>
  set_engine("kknn")

# SVM: 
svm_spec <- svm_rbf(cost = tune(), rbf_sigma = tune()) |>
  set_mode("classification") |>
  set_engine("kernlab")

# Random Forest: 
rf_spec <- rand_forest(mtry = tune(), trees = tune(), min_n = tune()) |>
  set_mode("classification") |>
  set_engine("ranger", importance = "impurity")

# Linear Discriminant Analysis (LDA)
lda_spec <- discrim_linear() |>
  set_mode("classification") |>
  set_engine("MASS")

# Quadratic Discriminant Analysis (QDA)
qda_spec <- discrim_quad() |>
  set_mode("classification") |>
  set_engine("MASS")
  2. Set up the workflow and add the model and recipe.
# Logistic Regression:
log_reg_wflow <- 
  workflow() |>
  add_model(log_reg_spec) |>            
  add_recipe(dat_recipe)              

# k-Nearest Neighbors:
knn_wflow <- 
  workflow() |>
  add_model(knn_spec) |>             
  add_recipe(dat_recipe)

# SVM:
svm_wflow <- 
  workflow() |>
  add_model(svm_spec) |>            
  add_recipe(dat_recipe)

# Random Forest:
rf_wflow <- 
  workflow() |>
  add_model(rf_spec) |>             
  add_recipe(dat_recipe)

# LDA: 
lda_wflow <- 
  workflow() |>
  add_model(lda_spec) |>                
  add_recipe(dat_recipe) 

# QDA: 
qda_wflow <- 
  workflow() |>
  add_model(qda_spec) |>                
  add_recipe(dat_recipe)
  3. Create a tuning grid specifying the range of each parameter to be tuned, along with the number of levels.
# Logistic Regression (penalty, mixture): 
# Penalty on log10-scale from 10^-4 to 10^0, mixture from 0 to 1
log_reg_grid <- grid_regular(
  penalty(range = c(-4, 0)),   
  mixture(range = c(0, 1)),
  levels = 5
)

# k-Nearest Neighbors: 
# Neighbors from 1 to 20
knn_grid <- grid_regular(
  neighbors(range = c(1, 20)),
  levels = 15
)

# SVM (cost, rbf_sigma): 
# Both on log10-scale, cost ~ 10^-3 .. 10^3, sigma ~ 10^-4 .. 10^0
svm_grid <- grid_regular(
  cost(range = c(-3, 3)),          
  rbf_sigma(range = c(-4, 0)),     
  levels = 5
)

# Random Forest (mtry, trees, min_n): 
# mtry up to about #predictors or fewer
rf_grid <- grid_regular(
  mtry(range = c(1, 8)),          
  trees(range = c(200, 1000)),     
  min_n(range = c(5, 20)),
  levels = 5
)

## Both LDA and QDA do not have tunable hyperparameters, so we skip the grid creation 
  4. Tune and fit each model by specifying the workflow, the cross-validation folds, and the tuning grid (since we store the results in RDS files in the next step, this chunk only needs to be run once).
# Logistic Regression (tuning penalty, mixture): 
log_reg_res <- tune_grid(
  log_reg_wflow,        # logistic_reg spec + recipe
  resamples = dat_folds,
  grid = log_reg_grid  # defined grid 
)

# k-Nearest Neighbors (tuning neighbors): 
knn_res <- tune_grid(
  knn_wflow,            
  resamples = dat_folds,
  grid = knn_grid
)

# SVM (tuning cost, rbf_sigma): 
svm_res <- tune_grid(
  svm_wflow,
  resamples = dat_folds,
  grid = svm_grid
)

# Random Forest (tuning mtry, trees, min_n): 
rf_res <- tune_grid(
  rf_wflow,
  resamples = dat_folds,
  grid = rf_grid
)

# LDA (fit with cross validation, no tuning grid): 
lda_res <- fit_resamples(
  lda_wflow,
  resamples = dat_folds, 
  control = control_grid(save_pred = TRUE)
)

# QDA: 
qda_res <- fit_resamples(
  qda_wflow,
  resamples = dat_folds,
  control = control_grid(save_pred = TRUE)
)
  5. Save the tuned models as RDS files to avoid rerunning them.
# Logistic Regression:
write_rds(log_reg_res, file = "tuned_models/log_reg.rds")  

# k-Nearest Neighbors:
write_rds(knn_res, file = "tuned_models/knn.rds") 

# SVM:
write_rds(svm_res, file = "tuned_models/svm.rds")  

# Random Forest:
write_rds(rf_res, file = "tuned_models/rf.rds")  

# LDA:
write_rds(lda_res, file = "tuned_models/lda.rds") 

# QDA: 
write_rds(qda_res, file = "tuned_models/qda.rds")  
  6. Load the saved RDS files back so the models never need to be rerun.
# Logistic Regression: 
log_tuned <- read_rds("tuned_models/log_reg.rds")

# k-Nearest Neighbors:
knn_tuned <- read_rds("tuned_models/knn.rds")

# SVM: 
svm_tuned <- read_rds("tuned_models/svm.rds")

# Random Forest: 
rf_tuned <- read_rds("tuned_models/rf.rds")

# LDA: 
lda_tuned <- read_rds("tuned_models/lda.rds")

# QDA: 
qda_tuned <- read_rds("tuned_models/qda.rds")
  7. Collect the metrics of the tuned models, filter for accuracy, and arrange in descending order to find each model’s highest accuracy. Then we use slice to keep the top row, saving it for later comparison between models.
# Logistic Regression:
best_log_reg_model <- collect_metrics(log_tuned) |>
  filter(.metric == "accuracy") |> 
  arrange(desc(mean)) |>
  dplyr::slice(1)

# k-Nearest Neighbors:
best_knn_model <- collect_metrics(knn_tuned) |>
  filter(.metric == "accuracy") |>
  arrange(desc(mean)) |>
  dplyr::slice(1)

# SVM:
best_svm_model <- collect_metrics(svm_tuned) |>
  filter(.metric == "accuracy") |>
  arrange(desc(mean)) |>
  dplyr::slice(1)

# Random Forest:
best_rf_model <- collect_metrics(rf_tuned) |>
  filter(.metric == "accuracy") |>
  arrange(desc(mean)) |>
  dplyr::slice(1)

# LDA:
best_lda_model <- collect_metrics(lda_tuned) |>
  filter(.metric == "accuracy") |>
  arrange(desc(mean)) |>
  dplyr::slice(1)

# QDA: 
best_qda_model <- collect_metrics(qda_tuned) |>
  filter(.metric == "accuracy") |>
  arrange(desc(mean)) |>
  dplyr::slice(1)

Model Results

After we find the best result for each model, it’s time to compare the accuracy between all models and see which model performs the best with the highest accuracy.

# Create a tibble for all the models and their accuracy rate 
model_names <- c("Logistic Regression", "K-Nearest Neighbors", 
                 "SVM", "Random Forest", "LDA", "QDA")

best_model_means <- c(best_log_reg_model$mean, best_knn_model$mean, 
                      best_svm_model$mean, best_rf_model$mean, 
                      best_lda_model$mean, best_qda_model$mean)

best_model_compare <- tibble(Model = model_names,
                                    Accuracy = best_model_means)

# Arrange models by highest accuracy (best model at the top)
best_model_compare <- best_model_compare |>
  arrange(desc(Accuracy))
best_model_compare

From the cross-validation results table, it is somewhat surprising to see that all the models do a fairly good job on the data, with KNN outperforming all the others. That result deserves scrutiny, so we will validate it later on the new data, dat_test. Since all accuracies are over 85%, we conclude that the data likely has a mixture of linear and nonlinear patterns, and that nonlinear models like KNN and Random Forest can better capture the complex decision boundaries.

Model Autoplots

The autoplots provide a more straightforward visualization of the effect of tuning on model performance at different parameter levels. In the following plots, we analyze performance based on accuracy, where higher is better.

Logistic Regression

autoplot(log_tuned, metric = "accuracy")

Our first model is logistic regression, with the penalty and mixture terms tuned. The plot shows the impact of regularization (and the Lasso penalty proportion) on model accuracy. Accuracy remains fairly stable at lower levels of regularization but drops sharply as the regularization grows heavier, approaching 1. This trend shows that a heavy penalty causes underfitting, while a smaller penalty leads to better performance.

Random Forest

autoplot(rf_tuned, metric = "accuracy")

For the Random Forest, we tuned the number of predictors sampled at each split (mtry), the number of trees (trees), and the minimal node size (min_n). Accuracy fluctuates slightly across configurations, but overall the variations are minor, suggesting that Random Forest is relatively stable across settings. The highest accuracy is observed at smaller minimal node sizes with a moderate number of randomly selected predictors. We can confirm this trend by checking the best Random Forest model, which uses mtry = 8 and 1,000 trees.

# check the best RF model 
best_rf_model

Support Vector Machine (SVM)

autoplot(svm_tuned, metric = "accuracy")

An SVM model is commonly used for binary classification tasks, and its performance lands in the top three on this heart failure dataset. The tuning result plot displays accuracy across varying cost and radial basis function (RBF) sigma values. Generally, higher accuracy is achieved when cost is moderately high and sigma is relatively small. However, for certain parameter combinations, performance degrades significantly. This suggests that the SVM is highly sensitive to hyperparameter selection, and an optimal balance between cost and sigma is crucial for maximizing accuracy.

ROC of LDA & QDA

lda_tuned_preds <- collect_predictions(lda_tuned)
qda_tuned_preds <- collect_predictions(qda_tuned)

lda_plot <- lda_tuned_preds |>
  roc_curve(heart_disease, .pred_0) |>
  autoplot() + 
  ggtitle("LDA ROC Curve")

qda_plot <- qda_tuned_preds |>
  roc_curve(heart_disease, .pred_0) |>
  autoplot() + 
  ggtitle("QDA ROC Curve")

grid.arrange(lda_plot, qda_plot, ncol = 2)

Lastly, we fit the data with LDA and QDA. Surprisingly, these two models do not show a significant difference in accuracy, even though LDA is mostly suited to linearly separable data while QDA handles nonlinear boundaries. Given the strong classification performance of both, we conclude that the data exhibits both linear and nonlinear patterns. From the plot, the LDA ROC curve appears slightly smoother and maintains high sensitivity across a wider range. Overall, even though these two models perform the worst among all the candidates, they still reach accuracies above 85%, which is fairly good.

Results of the Best Model

Performance on Folds

From the cross-validation performance, the best model is KNN with 18 neighbors (model #14 in the grid). It achieves a mean accuracy of 88.65% with a standard error of 0.01466.

# view the best model in the KNN 
best_knn_model

Now, we want to fit this “best” model to the entire training data and run a final test on the untouched dat_test.

Fitting to Training Data

In this section, we fit the best KNN model to the entire training data, training the model one more time so it is ready for the final test!

# finalize the KNN workflow with the best parameters
final_knn_model <- finalize_workflow(knn_wflow, best_knn_model)

# fit the final KNN model to the full training data
final_knn_train <- fit(final_knn_model, dat_train)

Testing the Model

Final Accuracy

Here we evaluate the fitted KNN model on dat_test, the held-out test split of the dataset.

# make predictions on the test set
final_knn_test <- predict(final_knn_train, new_data = dat_test) |>
  bind_cols(dat_test)

# compute accuracy
knn_test_metrics <- final_knn_test |>
  metrics(truth = heart_disease, estimate = .pred_class) 

# display the results
knn_test_metrics

The accuracy on the test data is approximately the same as the training accuracy, just slightly lower, which is a common gap between training and test performance. The 0.8622 accuracy indicates the model generalizes well to unseen data.
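Accuracy alone hides where the errors occur. As a supplementary look (not part of the original analysis), a confusion matrix from yardstick breaks the test predictions into false positives and false negatives, which matters in a medical screening context where missed cases are costly:

```r
# supplementary: confusion matrix of the test-set predictions
final_knn_test |>
  conf_mat(truth = heart_disease, estimate = .pred_class)
```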

ROC Representation

# compute accuracy using the class predictions
final_knn_prob <- predict(final_knn_train, new_data = dat_test, type = "prob") |>
  bind_cols(dat_test) 

# final ROC result 
final_knn_prob |>
  roc_auc(truth = heart_disease, .pred_0) 
# display the ROC graph 
final_knn_prob |>
  roc_curve(truth = heart_disease, .pred_0) |>
  autoplot()

We obtained a 91.97% ROC AUC, which together with the earlier 88.64% cross-validation accuracy suggests that the KNN model is very good at distinguishing between positive and negative cases. Additionally, the ROC curve bows sharply toward the top-left corner, reflecting the model’s excellent performance on this heart failure prediction task.

Conclusion

After training and evaluating several prediction models on our heart‐disease dataset, we found that k‐Nearest Neighbors (KNN) yielded the highest performance, with an accuracy of around 86% and an ROC AUC of about 0.92 on the test set. This was somewhat surprising, since tree‐based methods like Random Forest or boosted trees often perform strongly, and indeed Random Forest achieved nearly the same cross-validation accuracy as KNN. But in our particular data and preprocessing setup, KNN stood out, perhaps because the number of predictors in this dataset is small enough for a distance-based method to handle well.

Meanwhile, some of our other models (e.g., Logistic Regression and SVM) still performed reasonably well but fell short of KNN’s performance. In particular, LDA and QDA did not improve results as much as anticipated, which could be due to the size of the dataset or the nature of the predictors used.

There are a few ways we could extend or improve this analysis. First, feature engineering could be explored more deeply, such as creating additional clinical features to help distance‐based methods like KNN. Second, we might gather additional data or more diverse patient populations so that the models generalize better; although the age of the population is roughly normally distributed, it is still slightly right-skewed and lacks younger participants. Finally, trying more advanced algorithms like neural networks, or tuning hyperparameters with Bayesian optimization, could push performance even further.
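For instance, Bayesian tuning is available in tidymodels through tune_bayes(); a minimal sketch for the KNN workflow might look like the following (the initial and iter values are illustrative assumptions, not tuned choices):

```r
# sketch: Bayesian hyperparameter search for the KNN workflow
knn_bayes_res <- tune_bayes(
  knn_wflow,
  resamples = dat_folds,
  initial = 5,                   # seed results for the Gaussian process
  iter = 20,                     # Bayesian search iterations
  metrics = metric_set(accuracy)
)
```

show_best(knn_bayes_res, metric = "accuracy") would then surface the top candidates for comparison against the grid-search results.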

Despite some of these limitations, this project highlights that machine‐learning classification can be a useful tool for predicting heart disease risk. By comparing multiple models, I developed a stronger intuition about how each algorithm handles categorical vs. numeric variables, the importance of factor level ordering, and the effect of data preprocessing steps. Ultimately, I gained valuable experience with tidymodels and a pipeline that can be extended to more complex medical or clinical prediction tasks in the future.